
Initialize Qwen3.5 mutable buffers during export #17801

Closed

Phineas1500 wants to merge 2 commits into pytorch:main from Phineas1500:qwen3_5_phase2


Conversation


@Phineas1500 Phineas1500 commented Mar 2, 2026

Summary

  • factor mutable-buffer pass selection into _get_additional_export_passes
  • keep existing torchtune behavior (kv_cache_pos)
  • add Qwen3.5 buffer initialization patterns for export: k_cache, v_cache, conv_state, recurrent_state
  • wire this helper into both single-method and multimethod export paths
  • add unit tests covering pass selection for Qwen3.5, torchtune, and llama3 baseline

Why

Qwen3.5 uses internal mutable state (KV + DeltaNet recurrent/conv buffers). Initializing these buffers at export time avoids uninitialized mutable-buffer state and makes startup behavior deterministic.
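The determinism argument can be illustrated with a toy sketch (this is not the actual Qwen3.5 module; names and shapes are invented): a layer whose persistent state buffers start at zero produces the same first-step output on every fresh run, because the mutable state is part of the exported program rather than uninitialized memory.

```python
import torch
import torch.nn as nn

# Toy stand-in for a layer with mutable state buffers analogous to
# k_cache / conv_state / recurrent_state (illustrative only).
class ToyRecurrentLayer(nn.Module):
    def __init__(self, dim: int = 4, conv_dim: int = 3):
        super().__init__()
        # Zero-initialized persistent buffers: the initial state travels
        # with the module's state_dict, so startup is deterministic.
        self.register_buffer("recurrent_state", torch.zeros(1, dim, dim))
        self.register_buffer("conv_state", torch.zeros(1, dim, conv_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Mix the incoming token with the stored state, then update it.
        out = x + self.recurrent_state.sum(dim=-1)
        self.recurrent_state.add_(x.unsqueeze(-1))
        return out

layer = ToyRecurrentLayer()
x = torch.ones(1, 4)
y1 = layer(x)  # state was all zeros, so the first step is deterministic
y2 = layer(x)  # state now reflects the first token
```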

Test Plan

  • PYTHONPATH=src:/Users/sriram/Documents/executorch/third-party/ao:/tmp/qwen35_test/pytok pytest -q examples/models/llama/tests/test_export_llama_lib.py -k "qwen3_5_mutable_buffer_passes or torchtune_mutable_buffer_passes or llama3_has_no_extra_mutable_buffer_passes"
  • PYTHONPATH=src:/Users/sriram/Documents/executorch/third-party/ao:/tmp/qwen35_test/pytok pytest -q examples/models/llama/tests/test_qwen3_5_attention.py examples/models/qwen3_5/tests/test_convert_weights.py

Stacking

Copilot AI review requested due to automatic review settings March 2, 2026 23:50

pytorch-bot bot commented Mar 2, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/17801

Note: Links to docs will display an error until the docs builds have been completed.

❌ 3 New Failures, 4 Unrelated Failures

As of commit 4826903 with merge base 9d413ac:

NEW FAILURES - The following jobs have failed:

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla bot added the `CLA Signed` label (This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.) on Mar 2, 2026
@Phineas1500
Contributor Author

@pytorchbot label "release notes: none"

@pytorch-bot bot added the `release notes: none` label (Do not include this in the release notes) on Mar 2, 2026

Copilot AI left a comment


Pull request overview

This PR adds Qwen3.5 support to the Llama export pipeline with deterministic initialization of Qwen3.5’s internal mutable buffers (KV cache + DeltaNet recurrent/conv state) during export, and introduces the Qwen3.5 attention implementations/configs needed to run/export the hybrid layer layout.

Changes:

  • Add Qwen3.5 model types/configs and HF weight conversion utilities for ExecuTorch “meta” format.
  • Implement Qwen3.5 hybrid attention blocks (full attention + Gated DeltaNet linear attention) and wire hybrid layer construction into the Llama transformer.
  • Factor additional mutable-buffer initialization pass selection (torchtune + Qwen3.5) into a shared helper and add unit tests for pass selection and attention state reset.

Reviewed changes

Copilot reviewed 21 out of 21 changed files in this pull request and generated 1 comment.

Show a summary per file

  • extension/llm/export/config/llm_config.py: Adds Qwen3.5 model types to the export config enum.
  • examples/models/qwen3_5/tests/test_convert_weights.py: Unit test for Qwen3.5 HF→meta key mapping.
  • examples/models/qwen3_5/tests/__init__.py: Package marker/license header for Qwen3.5 tests.
  • examples/models/qwen3_5/convert_weights.py: Implements Qwen3.5 checkpoint loading and key conversion (incl. legacy packed tensor splitting).
  • examples/models/qwen3_5/config/qwen3_5_xnnpack_fp32.yaml: Adds an fp32/static-shape XNNPACK export config for Qwen3.5.
  • examples/models/qwen3_5/config/4b_config.json: Adds model args for Qwen3.5 4B (hybrid layer_types etc.).
  • examples/models/qwen3_5/config/2b_config.json: Adds model args for Qwen3.5 2B.
  • examples/models/qwen3_5/config/0_8b_config.json: Adds model args for Qwen3.5 0.8B.
  • examples/models/qwen3_5/__init__.py: Adds a Qwen3.5 model entrypoint (lazy subclass of Llama2Model) and exports convert_weights.
  • examples/models/qwen3_5/README.md: Documents export/run instructions for Qwen3.5 models.
  • examples/models/qwen3_5/BUCK: Adds Buck target for the Qwen3.5 Python library + deps.
  • examples/models/llama/tests/test_qwen3_5_attention.py: Adds tests for Qwen3.5 full-attn shape and DeltaNet state reset behavior.
  • examples/models/llama/tests/test_export_llama_lib.py: Adds tests covering export-pass selection for Qwen3.5/torchtune/llama3.
  • examples/models/llama/tests/BUCK: Registers the new Qwen3.5 attention unittest target.
  • examples/models/llama/norm.py: Extends RMSNorm to support Qwen3.5 "(1 + weight)" scaling.
  • examples/models/llama/model_args.py: Adds Qwen3.5 linear-attention dims + RMSNorm scaling flag to ModelArgs with defaults.
  • examples/models/llama/llama_transformer.py: Wires RMSNorm scaling flag and constructs DeltaNet layers when layer_types specify linear_attention.
  • examples/models/llama/export_llama_lib.py: Adds Qwen3.5 model ids, hooks Qwen3.5 weight conversion, and factors mutable-buffer init pass selection into a helper.
  • examples/models/llama/attention.py: Adds Qwen3.5 full attention and Gated DeltaNet attention implementations (+ required buffers).
  • examples/models/llama/__init__.py: Switches llama package export to lazy import pattern for Llama2Model.
  • examples/models/BUCK: Adds the Qwen3.5 model package to the umbrella models BUCK target.
Comments suppressed due to low confidence (1)

examples/models/llama/norm.py:60

  • RMSNorm currently returns output * self.weight when add_unit_offset is False. Since output is cast back to type_as(x) but self.weight stays fp32, this multiplication will promote the result to fp32 for fp16/bf16 inputs. The new add_unit_offset branch explicitly casts the weight to type_as(x), so the dtype behavior is now inconsistent between the two paths. Consider casting self.weight to type_as(x) (or otherwise ensuring the output dtype matches the input) in the non-offset path as well.
            return output * (1.0 + self.weight.float()).type_as(x)
        return output * self.weight
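The dtype promotion the review comment describes can be reproduced in isolation. This is a minimal sketch of the behavior, not the norm.py code itself: multiplying an fp16 tensor by an fp32 weight promotes the result to fp32, so the two branches disagree on output dtype unless both cast the scale with type_as(x).

```python
import torch

x = torch.randn(2, 4, dtype=torch.float16)
weight = torch.ones(4, dtype=torch.float32)

promoted = x * weight               # fp16 * fp32 promotes to fp32
consistent = x * weight.type_as(x)  # cast keeps the result in fp16
```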



    raise ValueError(
        f"Invalid packed in_proj_qkvz shape for {key}: {tuple(value.shape)}"
    )
qkv, z = torch.split(value, [conv_dim, value_dim], dim=0)
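The packed-tensor split quoted above relies on torch.split along dim 0. A minimal sketch with illustrative dimensions (these are not the real Qwen3.5 sizes): a stacked in_proj_qkvz weight of shape [conv_dim + value_dim, hidden] is split into its qkv and z parts.

```python
import torch

conv_dim, value_dim, hidden = 6, 4, 8

# Build a fake packed weight whose first conv_dim rows are the qkv part
# and whose last value_dim rows are the z part (illustrative only).
packed = torch.arange((conv_dim + value_dim) * hidden, dtype=torch.float32)
packed = packed.reshape(conv_dim + value_dim, hidden)

# torch.split with a list of sizes cuts the tensor along dim 0 into
# consecutive chunks of exactly those row counts.
qkv, z = torch.split(packed, [conv_dim, value_dim], dim=0)
```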

Copilot AI Mar 2, 2026


key_dim is computed when splitting legacy packed in_proj_qkvz.weight but is never used afterward. Please remove it or use it for an explicit shape validation to avoid dead code.

Copilot AI review requested due to automatic review settings March 3, 2026 00:39
@Phineas1500
Contributor Author

Validated export and runtime with the XNNPACK recipe.

Set max_seq_len and max_context_len to 128, generated the .pte, and ran executorch.examples.models.llama.runner.native with a multi-token prompt.

The model currently uses static-shape export in this path, but I added sequential token prefill.


Copilot AI left a comment


Pull request overview

Copilot reviewed 25 out of 25 changed files in this pull request and generated 5 comments.



Comment on lines +100 to +114
try:
    return self.forward(
        tokens=torch.tensor(
            [prompt_tokens], dtype=torch.long, device=self.device
        ),
        input_pos=torch.tensor([pos_base], dtype=torch.long, device=self.device),
    )
except RuntimeError:
    # Some exported models use a static single-token shape for kv-cache mode.
    # Fall back to sequential token prefill so multi-token prompts still work.
    if self.enable_dynamic_shape or len(prompt_tokens) <= 1:
        raise

    return self._sequential_kv_prefill(prompt_tokens, pos_base)


Copilot AI Mar 3, 2026


In _prefill_with_kv_cache, the early return for not self.enable_dynamic_shape and len(prompt_tokens) > 1 means the subsequent try/except RuntimeError never executes for the “static single-token shape” case described below. As written, the fallback logic is effectively dead code for the static-shape scenario; consider simplifying to a single path (either always sequential when static, or always try batched then fall back).

Suggested change:

Remove:

    try:
        return self.forward(
            tokens=torch.tensor(
                [prompt_tokens], dtype=torch.long, device=self.device
            ),
            input_pos=torch.tensor([pos_base], dtype=torch.long, device=self.device),
        )
    except RuntimeError:
        # Some exported models use a static single-token shape for kv-cache mode.
        # Fall back to sequential token prefill so multi-token prompts still work.
        if self.enable_dynamic_shape or len(prompt_tokens) <= 1:
            raise
        return self._sequential_kv_prefill(prompt_tokens, pos_base)

Replace with:

    return self.forward(
        tokens=torch.tensor(
            [prompt_tokens], dtype=torch.long, device=self.device
        ),
        input_pos=torch.tensor([pos_base], dtype=torch.long, device=self.device),
    )

Comment on lines +110 to +112
if self.enable_dynamic_shape or len(prompt_tokens) <= 1:
    raise


Copilot AI Mar 3, 2026


The except RuntimeError fallback is currently unreachable when self.enable_dynamic_shape is True (the default): the handler re-raises whenever self.enable_dynamic_shape is true, so a static exported model that throws on batched prefill (and lacks the enable_dynamic_shape metadata method) will fail instead of falling back to sequential prefill. Consider falling back based on len(prompt_tokens) > 1 (and/or the specific error) rather than the enable_dynamic_shape flag, or updating the flag when the batched call fails.

Suggested change:

Remove:

    if self.enable_dynamic_shape or len(prompt_tokens) <= 1:
        raise

Replace with:

    # If the batched prefill fails for a multi-token prompt, disable dynamic
    # shape support and retry using sequential prefill. For single-token
    # prompts, propagate the error.
    if len(prompt_tokens) <= 1:
        raise
    # Avoid retrying batched dynamic-shape prefill after a failure.
    self.enable_dynamic_shape = False
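The fallback shape being discussed can be sketched with hypothetical stand-in functions (none of these names come from the PR): try batched prefill once, and on failure for a multi-token prompt retry token-by-token rather than gating the fallback on a dynamic-shape flag.

```python
# Sketch only: `forward` stands in for the exported model call.
def prefill(forward, tokens, pos_base=0):
    try:
        return forward(tokens, pos_base)
    except RuntimeError:
        # Single-token prompts have nothing to fall back to.
        if len(tokens) <= 1:
            raise
        # Sequential fallback: feed one token at a time at increasing
        # positions, keeping the final step's output.
        out = None
        for i, tok in enumerate(tokens):
            out = forward([tok], pos_base + i)
        return out

# Toy forward that only accepts single-token inputs, mimicking a
# static-shape export.
def static_forward(tokens, pos):
    if len(tokens) != 1:
        raise RuntimeError("static shape expects 1 token")
    return ("logits", tokens[0], pos)
```

With this shape, a static export transparently degrades to sequential prefill: `prefill(static_forward, [5, 6, 7])` feeds tokens at positions 0, 1, 2.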

output = self._norm(x.float()).type_as(x)
if self.add_unit_offset:
    return output * (1.0 + self.weight.float()).type_as(x)
return output * self.weight

Copilot AI Mar 3, 2026


RMSNorm.forward returns output * self.weight when add_unit_offset is false, which will promote dtypes (e.g., fp16 input → fp32 output) because self.weight is fp32. In the new add_unit_offset branch you explicitly cast the scale to type_as(x), so the output dtype now depends on the flag. Consider casting self.weight (or the final product) to type_as(x) in both branches to keep output dtype consistent with the input.

Suggested change:

Remove:

    return output * self.weight

Replace with:

    return output * self.weight.type_as(x)

Comment on lines +47 to +53
try:
    self.enable_dynamic_shape = bool(
        self.model.run_method("enable_dynamic_shape")[0]
    )
except Exception:
    # Keep default behavior when metadata method is unavailable.
    pass

Copilot AI Mar 3, 2026


Catching a bare Exception around run_method("enable_dynamic_shape") can also hide real runtime issues (e.g., model load/ABI problems) and silently keep enable_dynamic_shape=True. It would be safer to catch the specific “method missing”/runtime exceptions raised by run_method (and optionally log at debug level) so unexpected failures don’t get swallowed.
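The narrower handling can be sketched as follows, using RuntimeError as a stand-in for whatever `run_method` raises when the method is absent (an assumption, not the actual ExecuTorch exception type), and all helper names here are hypothetical:

```python
import logging

logger = logging.getLogger(__name__)

def read_flag(run_method, default=True):
    """Read a boolean metadata method, falling back to a default only
    for the expected "method missing" failure."""
    try:
        return bool(run_method("enable_dynamic_shape")[0])
    except RuntimeError as e:
        # Log instead of silently swallowing, so unexpected failures
        # are still visible at debug level.
        logger.debug("enable_dynamic_shape metadata unavailable: %s", e)
        return default

# Toy stand-ins: one model lacks the method, the others return 0/1.
def missing(name):
    raise RuntimeError(f"method {name} not found")
```

A model without the metadata method keeps the default (`read_flag(missing)` returns True), while `read_flag(lambda name: [0])` reads the explicit False.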

Comment on lines +111 to +123
try:
    new_key = get_mapped_key(normalized_key, _QWEN_3_5_TO_META)
except Exception:
    # Ignore non-text weights and training-only extras (e.g., MTP).
    if (
        key.startswith("mtp.")
        or key.startswith("model.visual.")
        or ".vision_" in key
        or key.startswith("visual.")
    ):
        continue
    # Ignore unsupported keys that are not required by the export model.
    continue

Copilot AI Mar 3, 2026


The except Exception: ... continue around get_mapped_key will silently drop any unexpected keys (including genuinely required text weights if the mapping is incomplete or the checkpoint format changes). This makes conversion failures hard to detect. Consider only ignoring a well-defined allowlist of optional prefixes (vision/MTP/etc.) and re-raising for other model.* keys, or at least logging the first few unmapped keys at warning level.
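The allowlist-plus-raise approach suggested here can be sketched as follows (the helper names are hypothetical, not from the PR): only known-optional prefixes are dropped, and any other unmapped key fails loudly instead of being silently discarded.

```python
# Known-optional key families that the text-only export never needs
# (vision tower, multi-token-prediction head); illustrative allowlist.
OPTIONAL_PREFIXES = ("mtp.", "model.visual.", "visual.")

def should_ignore_unmapped_key(key: str) -> bool:
    return key.startswith(OPTIONAL_PREFIXES) or ".vision_" in key

def handle_unmapped(key: str) -> str:
    if should_ignore_unmapped_key(key):
        return "ignored"
    # A required text weight that fails to map is a conversion bug;
    # surface it instead of dropping the tensor.
    raise KeyError(f"Unmapped required checkpoint key: {key}")
```

Under this scheme an optional key like `mtp.head.weight` is skipped, while an unmapped `model.layers.*` key raises immediately.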

Copilot AI review requested due to automatic review settings March 3, 2026 22:12

Copilot AI left a comment


Pull request overview

Copilot reviewed 25 out of 25 changed files in this pull request and generated 3 comments.



):
    if _should_ignore_unmapped_key(key, normalized_key):
        continue
    continue

Copilot AI Mar 3, 2026


In the non-language-model key branch, _should_ignore_unmapped_key(...) currently has no effect because the code continues unconditionally whether the key is ignored or not. This makes it easy to silently drop unexpected checkpoint keys. Consider either (a) raising for non-ignored keys here, or (b) removing the ignore-check entirely if the intent is to ignore all non-text keys.

Suggested change:

Remove:

    continue

Replace with:

    raise ValueError(
        "Unexpected non-language-model checkpoint key not mapped for "
        f"Qwen3.5 export: {key}"
    )

"model.layers.{}.linear_attn.in_proj_b.weight": "layers.{}.attention.in_proj_b.weight",
"model.layers.{}.linear_attn.in_proj_a.weight": "layers.{}.attention.in_proj_a.weight",
"model.layers.{}.linear_attn.conv1d.weight": "layers.{}.attention.conv1d.weight",
"model.layers.{}.linear_attn.conv1d.bias": "layers.{}.attention.conv1d.bias",

Copilot AI Mar 3, 2026


The converter maps model.layers.*.linear_attn.conv1d.bias to layers.*.attention.conv1d.bias, but the corresponding ExecuTorch module (AttentionGatedDeltaNet) constructs conv1d with bias=False (no conv1d.bias parameter). If HF checkpoints can include this bias, it would be better to ignore/drop it during conversion (or flip the module to bias=True) to avoid carrying an unused tensor and relying on strict=False loads.

Suggested change:

Remove:

    "model.layers.{}.linear_attn.conv1d.bias": "layers.{}.attention.conv1d.bias",

Comment on lines +137 to +145
# Legacy packed tensors (older checkpoints):
#   in_proj_qkvz -> split into in_proj_qkv and in_proj_z
#   in_proj_ba -> split into in_proj_b and in_proj_a
if normalized_key.endswith(".linear_attn.in_proj_qkvz.weight"):
    pending_qkvz[normalized_key] = value
    continue
if normalized_key.endswith(".linear_attn.in_proj_ba.weight"):
    pending_ba[normalized_key] = value
    continue

Copilot AI Mar 3, 2026


The legacy packed-key handling (in_proj_qkvz / in_proj_ba) introduces non-trivial splitting logic that isn't covered by the new unit tests. Add a test case that includes a packed key plus the required out_proj.weight so the split shapes/keys (and error paths) are exercised.

Copilot AI review requested due to automatic review settings March 5, 2026 05:40

Copilot AI left a comment


Pull request overview

Copilot reviewed 25 out of 25 changed files in this pull request and generated no new comments.



Copilot AI review requested due to automatic review settings March 5, 2026 21:23

Copilot AI left a comment


Pull request overview

Copilot reviewed 25 out of 25 changed files in this pull request and generated 1 comment.



Comment on lines +138 to +156
def _get_additional_export_passes(model_class: str) -> List[InitializedMutableBufferPass]:
    patterns = []

    if model_class in TORCHTUNE_DEFINED_MODELS:
        patterns.append("kv_cache_pos")

    # Qwen3.5 uses internal mutable buffers for both the hybrid KV path and
    # DeltaNet recurrent/conv states.
    if model_class.startswith("qwen3_5"):
        patterns.extend(
            [
                "k_cache",
                "v_cache",
                "conv_state",
                "recurrent_state",
            ]
        )

    return [InitializedMutableBufferPass(patterns)] if patterns else []

Copilot AI Mar 5, 2026


InitializedMutableBufferPass causes matched mutated buffers to be serialized with their initial values. Initializing large buffers like k_cache/v_cache will therefore increase the exported .pte size (and potentially load time / memory pressure) by the full KV-cache tensor sizes. If this is expected, consider adding a config flag (or model-class-specific opt-out) so callers can choose determinism vs artifact size, and document the expected size impact for Qwen3.5 exports.
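A back-of-envelope sketch of the size impact, with illustrative shapes that are not the real Qwen3.5 config: serializing zero-initialized fp32 KV caches adds 2 caches x layers x [B, H, S, D] x 4 bytes to the artifact.

```python
def kv_cache_bytes(layers: int, batch: int, heads: int, seq: int,
                   head_dim: int, bytes_per_elem: int = 4) -> int:
    # Two caches (k and v) per layer, each [batch, heads, seq, head_dim].
    return 2 * layers * batch * heads * seq * head_dim * bytes_per_elem

# Hypothetical example: 24 layers, batch 1, 8 KV heads, seq 128,
# head_dim 128, fp32.
size = kv_cache_bytes(24, 1, 8, 128, 128)
size_mib = size // (1024 * 1024)  # 24 MiB of zeros serialized
```

The actual increase depends on max_seq_len, head counts, and dtype, which is why a config flag to opt out could matter for larger contexts.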

Copilot AI review requested due to automatic review settings March 5, 2026 22:13

Copilot AI left a comment


Pull request overview

Copilot reviewed 25 out of 25 changed files in this pull request and generated no new comments.



@Phineas1500
Contributor Author

@lucylq rebased

Copilot AI review requested due to automatic review settings March 6, 2026 20:57

Copilot AI left a comment


Pull request overview

Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.



Comment on lines +150 to +154
if model_class.startswith("qwen3_5"):
    patterns.extend(
        [
            "k_cache",
            "v_cache",

Copilot AI Mar 6, 2026


Initializing KV cache buffers via InitializedMutableBufferPass will cause their full tensor contents to be serialized into the .pte (the emitter treats et_init_buffer+mutable_buffer as const). For k_cache/v_cache this can be extremely large (per-layer [B, H, S, D]) and may blow up export size and load time. Consider avoiding initializing the full KV caches at export (e.g., only init the small state buffers like conv_state/recurrent_state, or add a runtime/cache-reset path that deterministically zeros these buffers without serializing them).

Suggested change:

Remove:

    if model_class.startswith("qwen3_5"):
        patterns.extend(
            [
                "k_cache",
                "v_cache",

Replace with:

    # Avoid initializing large KV cache buffers (k_cache/v_cache) here, since
    # InitializedMutableBufferPass would serialize their full contents into
    # the exported artifact, significantly increasing size and load time.
    if model_class.startswith("qwen3_5"):
        patterns.extend(
            [
Contributor


This is a good point. @Phineas1500 does qwen3.5 require initial state for the kv-cache, conv_state and recurrent_state?

The InitializedMutableBufferPass is only required for mutable buffers with initial state.

Contributor


Seems like a ~5 MB size increase from including initial state. Not sure why; I was expecting a bit more.

-rw-r--r-- 1 lfq users 4032780800 Mar  6 14:29 qwen3_5_0_8b_fp32_no_init.pte
-rw-r--r-- 1 lfq users 4038122240 Mar  5 11:05 qwen3_5_0_8b_fp32.pte

Output is the same with temp=0

(executorch) [lfq@devvm311.ldc0 /data/users/lfq/executorch (qwen3_5_phase2)]$ python -m executorch.examples.models.llama.runner.native --model qwen3_5_0_8b --pte qwen3_5_0_8b_fp32_no_init.pte --tokenizer ~/.cache/huggingface/hub/models--Qwen--Qwen3.5-0.8B/snapshots/2fc06364715b967f1860aea9cf38778875588b17/tokenizer.json --tokenizer_config ~/.cache/huggingface/hub/models--Qwen--Qwen3.5-0.8B/snapshots/2fc06364715b967f1860aea9cf38778875588b17/tokenizer_config.json --params examples/models/qwen3_5/config/0_8b_config.json --prompt "<|im_start|>user\nHello, what's 15% of 80?<|im_end|>\n<|im_start|>assistant\n" --max_len 128 -kv --temperature 0
I tokenizers:regex.cpp:27] Registering override fallback regex
Warning - given vocab_size in params is unequal to tokenizer vocab size.
[cpuinfo_utils.cpp:71] Reading file /sys/devices/soc0/image_version
[cpuinfo_utils.cpp:87] Failed to open midr file /sys/devices/soc0/image_version
[cpuinfo_utils.cpp:100] Reading file /sys/devices/system/cpu/cpu0/regs/identification/midr_el1
[cpuinfo_utils.cpp:109] Failed to open midr file /sys/devices/system/cpu/cpu0/regs/identification/midr_el1
[cpuinfo_utils.cpp:125] CPU info and manual query on # of cpus dont match.
<think>

</think>

To find 15% of 80, you can multiply 80 by 0.15:

$$80 \times 0.15 = 12$$

So, **15% of 80 is 12**.

Prefill time: 15.471495151519775
Generation tok/s: 2.097784149345492
Response: [248068, 271, 248069, 271, 1206, 1423, 220, 16, 20, 4, 314, 220, 23, 15, 11, 488, 628, 29283, 220, 23, 15, 539, 220, 15, 13, 16, 20, 25, 271, 13682, 23, 15, 1088, 14695, 220, 15, 13, 16, 20, 283, 220, 16, 17, 13682, 271, 4272, 11, 2972, 16, 20, 4, 314, 220, 23, 15, 369, 220, 16, 17, 159034, 248046]

Seems like the state is already zeroed here?
https://github.com/pytorch/executorch/blob/main/examples/models/llama/attention.py#L720

Comment on lines +153 to +154
"k_cache",
"v_cache",

Copilot AI Mar 6, 2026


InitializedMutableBufferPass matches patterns by substring. Using "k_cache"/"v_cache" here will also match other buffer names like "k_cache_scales", "k_cache_zero_points", or "past_k_caches_*" if present in the exported graph, potentially initializing/serializing more (large) buffers than intended. If you only mean the primary caches, consider narrowing the patterns to something less collision-prone (e.g., include a delimiter or full buffer name) or splitting by known FQN fragments.

Suggested change:

Remove:

    "k_cache",
    "v_cache",

Replace with:

    ".k_cache",
    ".v_cache",


@lucylq lucylq left a comment


I am not sure if this PR is necessary - seems like the recurrent state is initialized in the code _maybe_reset_state. I'm not sure kv-cache needs to be initialized. We get some binary size savings by not storing initial-value 0s in the .pte file but it's not as much as I expected only 5mb.

Let me know what you think @Phineas1500 . I added a longer comment in the code.

@Phineas1500
Contributor Author

I am not sure if this PR is necessary - seems like the recurrent state is initialized in the code _maybe_reset_state. I'm not sure kv-cache needs to be initialized. We get some binary size savings by not storing initial-value 0s in the .pte file but it's not as much as I expected only 5mb.

Let me know what you think @Phineas1500 . I added a longer comment in the code.

@lucylq I agree with you, the operations are unnecessary. I think this PR can be closed, since nothing else in it is important.

I think I'm going to make another PR addressing this concern #17800 (comment)

It'll improve performance if I replace the loop in attention.py that runs once per token (_recurrent_gated_delta_rule) with a custom op.

Let me know if you think that sounds good or not.

@lucylq
Contributor

lucylq commented Mar 6, 2026

I am not sure if this PR is necessary - seems like the recurrent state is initialized in the code _maybe_reset_state. I'm not sure kv-cache needs to be initialized. We get some binary size savings by not storing initial-value 0s in the .pte file but it's not as much as I expected only 5mb.
Let me know what you think @Phineas1500 . I added a longer comment in the code.

@lucylq I agree with you, the operations are unnecessary. I think this PR can be closed, since nothing else in it is important.

I think I'm going to make another PR addressing this concern #17800 (comment)

It'll improve performance if I replace the loop in attention.py that runs once per token (_recurrent_gated_delta_rule) with a custom op.

Let me know if you think that sounds good or not.

@Phineas1500 sounds good, this PR can be closed then.

Yes it would be great if you could optimize the recurrence! Do you also have time to work on quantization - perhaps 8da4w with xnnpack? (if not, I'll take a look).

@Phineas1500
Contributor Author

Yes that would be great! Do you also have time to work on quantization - perhaps 8da4w with xnnpack? (if not, I'll take a look).

@lucylq sure, happy to look at both. should i create a separate PR for each?

@lucylq
Contributor

lucylq commented Mar 6, 2026

Yes that would be great! Do you also have time to work on quantization - perhaps 8da4w with xnnpack? (if not, I'll take a look).

@lucylq sure, happy to look at both. should i create a separate PR for each?

yes please! Thanks so much, appreciate it 🙏

@Phineas1500 Phineas1500 closed this Mar 6, 2026
